{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Annotated\n",
    "\n",
    "from typing_extensions import TypedDict\n",
    "\n",
    "from langgraph.graph import StateGraph, START, END\n",
    "from langgraph.graph.message import add_messages\n",
    "from langgraph.checkpoint.memory import MemorySaver\n",
    "\n",
    "memory = MemorySaver()\n",
    "\n",
    "class State(TypedDict):\n",
    "    # Messages have the type \"list\". The `add_messages` function\n",
    "    # in the annotation defines how this state key should be updated\n",
    "    # (in this case, it appends messages to the list, rather than overwriting them)\n",
    "    messages: Annotated[list, add_messages]\n",
    "\n",
    "\n",
    "graph_builder = StateGraph(State)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.llms.tongyi import Tongyi\n",
    "\n",
    "llm = Tongyi(model=\"qwen-turbo\")\n",
    "\n",
    "def chatbot(state: State):\n",
    "    return {\"messages\": [llm.invoke(state[\"messages\"])]}\n",
    "\n",
    "\n",
    "# The first argument is the unique node name\n",
    "# The second argument is the function or object that will be called whenever\n",
    "# the node is used.\n",
    "graph_builder.add_node(\"chatbot\", chatbot)\n",
    "graph_builder.add_edge(START, \"chatbot\")\n",
    "graph_builder.add_edge(\"chatbot\", END)\n",
    "graph = graph_builder.compile(checkpointer=memory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================\u001b[1m Human Message \u001b[0m=================================\n",
      "\n",
      "Hi there! My name is Will.\n",
      "================================\u001b[1m Human Message \u001b[0m=================================\n",
      "\n",
      "Hello Will! It's nice to meet you. How can I assist you today?\n"
     ]
    }
   ],
   "source": [
    "# 聊天 ID\n",
    "config = {\"configurable\": {\"thread_id\": \"1\"}}\n",
    "\n",
    "user_input = \"Hi there! My name is Will.\"\n",
    "\n",
    "# The config is the **second positional argument** to stream() or invoke()!\n",
    "events = graph.stream(\n",
    "    {\"messages\": [(\"user\", user_input)]}, config, stream_mode=\"values\"\n",
    ")\n",
    "for event in events:\n",
    "    event[\"messages\"][-1].pretty_print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================\u001b[1m Human Message \u001b[0m=================================\n",
      "\n",
      "Remember my name?\n",
      "================================\u001b[1m Human Message \u001b[0m=================================\n",
      "\n",
      "Of course, Will! I remember your name. How can I assist you today?\n"
     ]
    }
   ],
   "source": [
    "user_input = \"Remember my name?\"\n",
    "\n",
    "# The config is the **second positional argument** to stream() or invoke()!\n",
    "events = graph.stream(\n",
    "    {\"messages\": [(\"user\", user_input)]}, config, stream_mode=\"values\"\n",
    ")\n",
    "for event in events:\n",
    "    event[\"messages\"][-1].pretty_print()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================\u001b[1m Human Message \u001b[0m=================================\n",
      "\n",
      "Remember my name?\n",
      "================================\u001b[1m Human Message \u001b[0m=================================\n",
      "\n",
      "Of course! Could you please remind me of your name? I want to make sure I have the correct information.\n"
     ]
    }
   ],
   "source": [
    "# The only difference is we change the `thread_id` here to \"2\" instead of \"1\"\n",
    "events = graph.stream(\n",
    "    {\"messages\": [(\"user\", user_input)]},\n",
    "    {\"configurable\": {\"thread_id\": \"2\"}},\n",
    "    stream_mode=\"values\",\n",
    ")\n",
    "for event in events:\n",
    "    event[\"messages\"][-1].pretty_print()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Assistant: Sure, you can call me J! How can I assist you today, J? 😊\n",
      "Assistant: My name is Qwen. I am a large language model created by Alibaba Cloud.\n",
      "Goodbye!\n"
     ]
    }
   ],
   "source": [
    "def stream_graph_updates(user_input: str):\n",
    "    for event in graph.stream({\"messages\": [(\"user\", user_input)]}):\n",
    "        for value in event.values():\n",
    "            print(\"Assistant:\", value[\"messages\"][-1])\n",
    "\n",
    "\n",
    "while True:\n",
    "    try:\n",
    "        user_input = input(\"User: \")\n",
    "        if user_input.lower() in [\"quit\", \"exit\", \"q\"]:\n",
    "            print(\"Goodbye!\")\n",
    "            break\n",
    "\n",
    "        stream_graph_updates(user_input)\n",
    "    except:\n",
    "        # fallback if input() is not available\n",
    "        user_input = \"What do you know about LangGraph?\"\n",
    "        print(\"User: \" + user_input)\n",
    "        stream_graph_updates(user_input)\n",
    "        break"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
